import numpy as np
from NetRate_batch import NetRate_batch
from utils import *

import matplotlib.pyplot as plt
import copy
import time

# Fixed seed so the random initial estimates below are reproducible across runs.
np.random.seed(233)

# Distribution name: passed to NetRate_batch and used as the cascade-file prefix.
dist = "exp"
# Passed straight through to NetRate_batch (meaning defined there).
delta = 1

# Number of nodes. NOTE(review): assumed to match the shape of the true matrices.
nd = 500

# Ground-truth matrices. NpzFile objects keep the archive open, so load them
# inside context managers to release the file handles promptly.
with np.load("true_A.npz") as npz_A:
    A = npz_A['A']
with np.load("true_B.npz") as npz_B:
    B = npz_B['B']

# Reference B with self-loops removed; all B metrics are computed against this.
B_use = B.copy()
np.fill_diagonal(B_use, 0)
# 0/1 mask of A's support; used to split an estimate into its A and B parts.
supp_A = (A > 0).astype(int)

# Number of independent experiments (one cascade file per experiment).
num = 5

# Per-experiment metrics for the initial estimate fed to the optimizer.
A_ini_mae_collect = np.zeros(num)
B_ini_mae_collect = np.zeros(num)
B_Acc_ini_collect = np.zeros(num)
B_Pre_ini_collect = np.zeros(num)
B_Recall_ini_collect = np.zeros(num)

# Per-experiment metrics for the optimized (NetRate) estimate.
A_mae_collect = np.zeros(num)
B_mae_collect = np.zeros(num)
B_Acc_collect = np.zeros(num)
B_Pre_collect = np.zeros(num)
B_Recall_collect = np.zeros(num)

# NOTE(review): nc is not referenced in the visible script.
nc = 50000
# Passed to model.optimize alongside the cascades (meaning defined there).
t = 10

def _evaluate(est_A, est_B):
    """Score a split estimate against the ground truth.

    Returns a tuple (A normalized MAE, B normalized MAE, B accuracy,
    B precision, B recall), using threshold 0 for every metric.
    """
    return (
        get_normalized_mae(est_A, A, hard_thres=0),
        get_normalized_mae(est_B, B_use, hard_thres=0),
        get_acc(est_B, B_use, 0),
        get_pre(est_B, B_use, 0),
        get_recall(est_B, B_use, 0),
    )


for ii in range(num):
    # Cascade files are numbered from 1; close the NpzFile once the array is read.
    with np.load(dist + "_cascade_{}.npz".format(ii+1)) as cascade_file:
        cascades = cascade_file['cascades']

    # Random non-negative values restricted to the supports of A and B.
    A_ini = np.maximum(np.random.randn(*A.shape), 0) * A
    B_ini = np.maximum(np.random.randn(*B.shape), 0) * B

    # Initial estimate: combined A+B, shifted by a small constant so every
    # off-diagonal entry starts positive, with self-loops removed.
    A_update = A_ini + B_ini
    A_update += np.ones_like(A) * 0.075
    np.fill_diagonal(A_update, 0)

    # Score the initial estimate the same way as the final one. Do this before
    # optimizing, so that in-place changes by the model cannot affect it.
    init_A = A_update * supp_A
    init_B = A_update * (1 - supp_A)
    np.fill_diagonal(init_B, 0)
    (A_ini_mae_collect[ii], B_ini_mae_collect[ii], B_Acc_ini_collect[ii],
     B_Pre_ini_collect[ii], B_Recall_ini_collect[ii]) = _evaluate(init_A, init_B)

    start = time.perf_counter()

    ############################################### MNCS EM ###############################################
    model = NetRate_batch(dist, delta, lr=1e-6, penl_l1=0.005, hard_thres=0.01, eps=0.000000001, batch_size=2000, max_Iter=20)
    res = model.optimize(A_update, cascades, t)
    #######################################################################################################

    A_update = res['A']
    loss = res['loss']

    end = time.perf_counter()
    print("time:", end - start)

    # NOTE(review): assumes res['loss'] is a sequence of torch tensors — confirm in NetRate_batch.
    loss = [tensor.detach().cpu().numpy() for tensor in loss]

    plt.plot(loss)
    plt.show()

    # Split the optimized matrix into the part on A's support and the rest (B).
    res_A = A_update * supp_A
    res_B = A_update * (1 - supp_A)

    # Saved before the diagonal of res_B is zeroed, matching the original output files.
    np.savez_compressed("res_{}.npz".format(ii+1), A=res_A, B=res_B, loss=loss)

    np.fill_diagonal(res_B, 0)

    (A_mae_collect[ii], B_mae_collect[ii], B_Acc_collect[ii],
     B_Pre_collect[ii], B_Recall_collect[ii]) = _evaluate(res_A, res_B)

    print("experiment", ii, "finished\n")

# Print mean and standard deviation of each metric, in the order:
# A normalized MAE, B normalized MAE, B accuracy, B precision, B recall.
_summaries = (
    ("Init Result:",
     (A_ini_mae_collect, B_ini_mae_collect, B_Acc_ini_collect, B_Pre_ini_collect, B_Recall_ini_collect),
     ("\n",)),
    ("NetRate Result:",
     (A_mae_collect, B_mae_collect, B_Acc_collect, B_Pre_collect, B_Recall_collect),
     ()),
)

for _title, _metrics, _trailer in _summaries:
    print(_title)
    print("mean:", *map(np.mean, _metrics))
    # The init section appends an extra "\n" argument to separate the two blocks.
    print("sd:", *map(np.std, _metrics), *_trailer)